Skip to content

Add PER_TOKEN_HEAD FP8 quant and P-scale to batch_prefill#7883

Open
msaffari-amd wants to merge 2 commits into
developfrom
users/msaffari-amd/ck/add_per_head_per_token_quant_fmha
Open

Add PER_TOKEN_HEAD FP8 quant and P-scale to batch_prefill#7883
msaffari-amd wants to merge 2 commits into
developfrom
users/msaffari-amd/ck/add_per_head_per_token_quant_fmha

Conversation

@msaffari-amd
Copy link
Copy Markdown
Contributor

Adds a new FP8 quantization scheme (PER_TOKEN_HEAD) to the CK batch_prefill FMHA kernel, along with optional per-query-head P-scale support.

Motivation

Existing FP8 quant modes (PERTENSOR, KV_BLOCKSCALE) applies descaling that doesn't capture per-token or per-head variance in activation magnitudes. PER_TOKEN_HEAD enables descaling for Q and K at per-token-per-head granularity.

Technical Details

Quantization scheme
Tensor Descale granularity Shape
Q per-token, per-head [total_q, nhead_q]
K per-token, per-head (paged) [num_total_pages, page_block_size, nhead_k]
V per-head [nhead_k]
The QK dequantization (s_acc[i,j] *= q_descale[i] * k_descale[j]) is staged through LDS to minimize inner-loop register pressure matching the approach used in the fmha_fwd pipeline.

P-scale
An optional per-q-head P-scale [num_head_q] is supported. log2(p_scale) is folded into the exp2 row-max shift, so the scale factor appears in both P and the rowsum l, cancelling in O = sum(P·V) / l without needing a separate fixup.

Cross-page support
Unlike KV_BLOCKSCALE (which requires page_block_size >= kN0), PER_TOKEN_HEAD supports cross-page tiles by precomputing per-column physical page indices. This enables page_size=64 (newly added to the codegen list).

Add a new FP8 quantization scheme (PER_TOKEN_HEAD, enum value 5) for the
batch_prefill FMHA kernel. Unlike PERTENSOR (single scale for all of Q/K/V)
or KV_BLOCKSCALE (per-page K/V scales), PER_TOKEN_HEAD applies fine-grained
descales:

  - Q descale: per-token, per-head  [total_q, nhead_q]
  - K descale: per-token, per-head  [num_total_pages, page_block_size, nhead_k]
  - V descale: per-head             [nhead_k]

The dequantization of the QK dot product is staged through LDS to avoid
inflating the inner-loop instruction footprint. Cross-page tiles
(page_block_size < kN0) are supported via per-column physical page lookup,
unlike KV_BLOCKSCALE which requires page_block_size >= kN0.

Additionally, an optional per-q-head P-scale [num_head_q] is supported.
The kernel folds log2(p_scale) into the exp2 row-max shift, so the scale
factor appears in both P and the rowsum l, cancelling in O = sum(P*V) / l
with no separate V-descale fixup needed.

Also adds page_size=64 to the codegen page size list, and includes SRD
same-page-skip optimizations for K/V window rebasing.

Changes:
  - block_attention_quant_scale_enum.hpp: PER_TOKEN_HEAD = 5
  - quant.hpp: enum, serialize ("pth"), decode
  - cpp_symbol_map.py: codegen symbol mappings
  - fmha_batch_prefill.py: page_size=64, per_token_head qscale, filter update
  - fmha_fwd.hpp: args struct (stride fields, p_scale_ptr), kargs forwarding
  - fmha_batch_prefill_kernel.hpp: kargs struct, MakeKargs, get_scale_s,
    pipeline dispatch
  - block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: LDS-staged dequant,
    p_scale_log2 exp2-shift fold, cross-page support, SRD same-page skip,
    PER_TOKEN_HEAD convenience overload
New decode-aligned KV cache layout for FP8 PER_TOKEN_HEAD batch_prefill:
5D vectorized K + 4D ColumnMajor V [NumBlocks, NumHeads, HeadDim, PageSize].
Matches the layout produced by reshape_and_cache and consumed by the decode
paged-attention kernel, so prefill can ingest the live KV cache without an
intermediate reshape.

- block_attention_kvcache_layout_enum.hpp: add VEC_K_COL_V_LAYOUT (= 2).
- fmha_batch_prefill_kernel.hpp: route K dram through the vectorized branch
  for VEC_K_COL_V; add a new V dram branch building (Pages, HeadDim,
  PageSize) with stride (batch_stride_v, page_block_size, 1) and merging to
  logical (D, TotalSeqK).
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: keep kAlignmentV =
  kMaxVecLoad for VEC_K_COL_V despite kPadSeqLenK=true (full-page invariant
  keeps vec loads safe along the contiguous PageSize dim).
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp: new
  kUseVectorizedVPolicy<Problem>() predicate routes VEC_K_COL_V through
  the same V tile dist / LDS layout / SmemKPack / alignment as
  VECTORIZED_LAYOUT.
- block_fmha_pipeline_problem.hpp: relax IsVLayoutRowMajor static_assert to
  accept ColumnMajor V for VEC_K_COL_V; introduce kIsKVectorized predicate
  so the page_size=1 + K-vectorized rejection covers the new layout.
- tile_fmha_traits.hpp: extend the supported-layouts static_assert.
- fmha_fwd.hpp: add `bool is_v_rowmajor = true` to fmha_batch_prefill_args
  so the wrapper can flip it for VEC_K_COL_V.
- codegen/ops/fmha_batch_prefill.py: add SUPPORTED_KV_MEMORY_LAYOUT_FP8_PTH_EXTRA
  map entry and a gated emission loop for fp8bf16 PER_TOKEN_HEAD with
  vlayout="col" + kv_memory_layout="vec_k_col_v" across both lookup tables.
  Relax receipt 200 to allow vlayout="col" only when kv_memory_layout
  == "vec_k_col_v".
@amd-yashagar amd-yashagar force-pushed the users/msaffari-amd/ck/add_per_head_per_token_quant_fmha branch from 20e4392 to 475d8ae Compare May 29, 2026 16:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant